import tensorflow as tf
from tensorflow.keras.layers import Layer


class MultiHeadAttention(Layer):
    """Multi-head scaled dot-product attention.

    Projects the input into ``heads`` independent subspaces of width
    ``head_size``, runs scaled dot-product attention in each head, then
    concatenates the heads and applies a final output projection.

    Call accepts either a single tensor of shape ``(batch, seq, dim)``
    (self-attention: query = key = value, the original behaviour) or a
    ``[query, key, value]`` list/tuple for cross-attention.

    Args:
        heads: Number of attention heads.
        head_size: Dimensionality of each head.
        out_dim: Units of the final output projection. Defaults to
            ``heads * head_size``. (Previously ``None`` was forwarded
            into ``Dense(units=None)`` and crashed at build time.)
        use_bias: Whether the four Dense projections use a bias term.
        **kwargs: Standard ``tf.keras.layers.Layer`` kwargs
            (``name``, ``dtype``, ...).
    """

    def __init__(
            self,
            heads,
            head_size,
            out_dim=None,
            use_bias=True,
            **kwargs,
    ):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.heads = heads
        self.head_size = head_size
        # Bug fix: a None out_dim used to reach Dense(units=None) and
        # raise; default to the concatenated head width instead.
        self.out_dim = out_dim if out_dim is not None else heads * head_size
        self.use_bias = use_bias

    def build(self, input_shape):
        # Query / key / value projections, each producing
        # heads * head_size features that are later split per head.
        self.q_dense = tf.keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
        )
        self.k_dense = tf.keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
        )
        self.v_dense = tf.keras.layers.Dense(
            units=self.head_size * self.heads,
            use_bias=self.use_bias,
        )
        # Output projection applied after the heads are re-concatenated.
        self.o_dense = tf.keras.layers.Dense(
            units=self.out_dim,
            use_bias=self.use_bias,
        )
        # Keras convention: mark the layer built after creating sublayers.
        super(MultiHeadAttention, self).build(input_shape)

    def call(self, inputs):
        # Accept a [q, k, v] list (cross-attention) or a single tensor
        # (self-attention, the original behaviour).
        if isinstance(inputs, (list, tuple)):
            q, k, v = inputs[:3]
        else:
            q = k = v = inputs
        # Linear projections.
        qw = self.q_dense(q)
        kw = self.k_dense(k)
        vw = self.v_dense(v)
        # Split the feature axis into (heads, head_size). Keys/values use
        # their own sequence length so cross-attention shapes are correct.
        qw = tf.reshape(qw, (-1, tf.shape(q)[1], self.heads, self.head_size))
        kw = tf.reshape(kw, (-1, tf.shape(k)[1], self.heads, self.head_size))
        vw = tf.reshape(vw, (-1, tf.shape(k)[1], self.heads, self.head_size))
        # Scaled dot-product attention per head.
        o = self.pay_attention_to([qw, kw, vw])
        # Re-concatenate heads and project to out_dim.
        o = tf.reshape(o, (-1, tf.shape(o)[1], self.head_size * self.heads))
        o = self.o_dense(o)
        return o

    def pay_attention_to(self, inputs):
        """Scaled dot-product attention over per-head tensors.

        Args:
            inputs: list of ``[qw, kw, vw]``, each of shape
                ``(batch, seq, heads, head_size)``.

        Returns:
            Tensor of shape ``(batch, q_seq, heads, head_size)``.
        """
        (qw, kw, vw) = inputs[:3]
        # Attention logits: (batch, heads, q_seq, k_seq).
        a = tf.einsum('bjhd,bkhd->bhjk', qw, kw)
        # Scale by sqrt(d) to keep logits in a stable range.
        a = a / self.head_size ** 0.5
        # Normalize over the key axis (last axis).
        A = tf.nn.softmax(a, axis=-1)
        # Weighted sum of values, back to (batch, q_seq, heads, head_size).
        o = tf.einsum('bhjk,bkhd->bjhd', A, vw)
        return o

    def get_config(self):
        """Enable serialization (``tf.keras.models.clone_model`` etc.)."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({
            'heads': self.heads,
            'head_size': self.head_size,
            'out_dim': self.out_dim,
            'use_bias': self.use_bias,
        })
        return config